#!/usr/bin/python

import hashlib
import hmac
import logging
import logging.handlers
import os
import re
import sys
import time

# Module-level logger shared by authenticate() and main(); main() sets its
# level to INFO and attaches a syslog handler to the root logger.
log = logging.getLogger('auth-openvpn')

def _read_credentials(path):
    """Read the username and password stored in the file at ``path``.

    The first line holds the username and the second the password; each
    must be terminated by a newline and fit in 8192 bytes.  Both values
    are returned as byte strings with the newline removed.

    Raises ValueError if either line is missing, unterminated or empty.
    """
    values = []
    with open(path, 'rb') as f:
        for label in ('username', 'password'):
            value = f.readline(8192)
            # Explicit checks instead of ``assert``: asserts are stripped
            # under ``python -O``, which would let bad input through and
            # silently chop the last character of the value.
            if not value.endswith(b'\n'):
                raise ValueError(
                    '%s line is missing or not newline-terminated' % label)
            value = value[:-1]
            if not value:
                raise ValueError('%s is empty' % label)
            values.append(value)
    return tuple(values)


def _find_user(wanted):
    """Return the ``(salt, digest)`` byte strings recorded for ``wanted``.

    The users file holds one ``username salt digest`` record per line,
    fields separated by single spaces.  Empty lines and lines starting
    with ``#`` are skipped.  If no record matches, the miss is logged and
    a placeholder pair is returned so the caller still performs the hash
    and comparison before rejecting the login.

    Raises ValueError if a line is not newline-terminated or a record has
    fewer than three fields.
    """
    with open('{{ openvpn_data_dir }}/users', 'rb') as f:
        for line in f:
            if not line.endswith(b'\n'):
                raise ValueError('users file has an unterminated line')
            line = line[:-1]
            if not line or line.startswith(b'#'):
                continue
            (username, salt, correct) = line.split(b' ', 2)
            if username == wanted:
                return (salt, correct)

    # these will never match: a hex digest cannot contain 'x'
    log.error('User not found: %r', wanted)
    return (b'not-found', 64 * b'x')


def _hash_secret(salt, secret):
    """Return hex(sha256(sha256(salt + secret) + salt)) as a byte string."""
    inner = hashlib.sha256(salt + secret).digest()
    return hashlib.sha256(inner + salt).hexdigest().encode('ascii')


def authenticate():
    """Check the credentials in the file named by ``sys.argv[1]``.

    Returns normally when the salted hash of the supplied password equals
    the digest stored for the user; otherwise logs the failure and exits
    the process with status 1.  Malformed input raises ValueError.

    NOTE(review): presumably run by OpenVPN as its password verification
    script -- confirm against the server configuration.
    """
    # annoy attackers
    time.sleep(1)

    (user, secret) = _read_credentials(sys.argv[1])

    # From openvpn(8):
    #
    # To protect against a client passing a maliciously formed username or
    # password string, the username string must consist only of these
    # characters: alphanumeric, underbar ('_'), dash ('-'), dot ('.'), or
    # at ('@'). The password string can consist of any printable
    # characters except for CR or LF. Any illegal characters in either the
    # username or password string will be converted to underbar ('_').
    #
    # We'll just redo that quickly for usernames, to ensure they are safe.

    user = re.sub(br'[^a-zA-Z0-9_.@-]', b'_', user)

    (salt, correct) = _find_user(user)
    attempt = _hash_secret(salt, secret)

    # Constant-time comparison, so the time taken does not reveal how many
    # leading characters of the digest matched.
    if not hmac.compare_digest(attempt, correct):
        log.error('{prog}: invalid auth for user {user!r}.'.format(prog=os.path.basename(sys.argv[0]), user=user))
        sys.exit(1)

def main():
    """Set up logging, then run authenticate().

    Calls ``logging.basicConfig()`` and additionally attaches a syslog
    handler (``/dev/log``, daemon facility) to the root logger, formatting
    records as ``<logger name>: <message>``.  The module logger's level is
    set to INFO.

    Any exception escaping authenticate() is logged with its traceback and
    re-raised.  Returns None when authenticate() returns normally.
    """
    handler = logging.handlers.SysLogHandler(
        address='/dev/log',
        facility=logging.handlers.SysLogHandler.LOG_DAEMON,
        )
    handler.setFormatter(logging.Formatter('%(name)s: %(message)s'))
    logging.basicConfig()
    logging.getLogger('').addHandler(handler)
    log.setLevel(logging.INFO)

    try:
        authenticate()
    except Exception:
        # SystemExit and KeyboardInterrupt derive from BaseException, not
        # Exception, so they pass through here without being logged as
        # unhandled errors.
        log.exception('Unhandled error: ')
        raise

# Script entry point: run main() and hand its return value to sys.exit().
if __name__ == '__main__':
    status = main()
    sys.exit(status)
